{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gensim.models.keyedvectors import KeyedVectors\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.manifold import TSNE\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import matplotlib"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load in the models\n",
    "\n",
    "These are the models created in the `make_embeddings.ipynb` notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m4a_model = KeyedVectors.load_word2vec_format('../data/word2vec_m4a_clean.txt')\n",
    "lds_model = KeyedVectors.load_word2vec_format('../data/word2vec_lds_clean.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Look at words that are \"similar\" to each other in the space\n",
    "\n",
    "For a given word, identify the worlds that are most similar to it, and return them as a dataframe. We will use this to compare the embeddings created by the models for the different subreddits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_similar_words(models, s, n=10):\n",
    "    '''\n",
    "    Get the n most similar words to str from the models.\n",
    "    The models should be a list of dictionaries, in the format\n",
    "    [{'model_name': MODEL_NAME, 'model': gensim.models.keyedvectors.KeyedVectors object}].\n",
    "    '''\n",
    "    candidates = []\n",
    "    for model in models:\n",
    "        candidates += [x[0] for x in model['model'].most_similar(s, topn=n)]\n",
    "    candidates = set(candidates)\n",
    "\n",
    "    similarities = []\n",
    "    for candidate in candidates:\n",
    "        if len(candidate) > 30:\n",
    "            continue\n",
    "        for model in models:\n",
    "            try:\n",
    "                similarities.append({'original_word': s, 'similar_word': candidate, 'model_name': model['model_name'], 'similarity': model['model'].similarity(s, candidate)})\n",
    "            except KeyError:\n",
    "                pass\n",
    "    \n",
    "    similarities_df = pd.DataFrame(similarities)\n",
    "    similarities_df.loc[similarities_df['model_name'] == 'r/Masks4All', 'similarity'] = 0 - similarities_df.loc[similarities_df['model_name'] == 'r/Masks4All', 'similarity']\n",
    "    similarities_df = similarities_df.sort_values(by='similarity', ascending=False)\n",
    "    return similarities_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_figure(df):\n",
    "    original_word = df.iloc[0]['original_word']\n",
    "    # Create a figure\n",
    "    plt.figure(figsize=(10, 6))\n",
    "    ax = sns.barplot(data=df, y='similar_word', x='similarity', dodge=False, hue='model_name', orient='h')# title='Similarity of words to \"fear\" in different models')\n",
    "    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), title='Community')\n",
    "    ax.set_title(f'Similarity to the term \"{original_word}\"')\n",
    "    # Remove y-axis label\n",
    "    ax.set_ylabel('')\n",
    "    # Make all x-axis labels display as positive using xticks\n",
    "    labels = ax.get_xticklabels()\n",
    "    ax.set_xticks(ax.get_xticks())\n",
    "    ax.set_xticklabels([x.get_text().strip('−') for x in labels])\n",
    "\n",
    "\n",
    "\n",
    "    # Save the figure\n",
    "    plt.savefig(f'../figures/similarity_{original_word}.png', bbox_inches='tight')\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = [{'model': m4a_model, 'model_name': 'r/Masks4All'}, {'model': lds_model, 'model_name': 'r/LockdownSkepticism'}]\n",
    "\n",
    "make_figure(get_similar_words(models, 'fear', n=10));\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for term in ['fear', 'responsibility', 'masks', 'lockdown']:\n",
    "    make_figure(get_similar_words(models, term, n=10));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also look at how similar two terms are"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_similarity(terms, comparison_term):\n",
    "    result = []\n",
    "    models = {'m4a': m4a_model, 'lds': lds_model}\n",
    "    for term in terms:\n",
    "        for community, model in models.items():\n",
    "            result.append({'community': community,\n",
    "                          'term': term,\n",
    "                          'similarity': model.similarity(comparison_term, term)\n",
    "                         })    \n",
    "    df = pd.DataFrame(result)\n",
    "    return sns.barplot(data=df, x='term', y='similarity', hue='community').set(title=f'Similarity to the word \"{comparison_term}\"');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_similarity(['masks', 'covid', 'nurses', 'children', 'sickness'], 'fear');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_similarity(['masks', 'covid', 'nurses', 'children', 'sickness'], 'protect');"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "teaching",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
